
Bonus Tutorial 4: The Kalman Filter, part 2

Week 3, Day 2: Hidden Dynamics

By Neuromatch Academy

Content creators: Caroline Haimerl and Byron Galbraith

Content reviewers: Jesse Livezey, Matt Krause, Michael Waskom, and Xaq Pitkow

Post-production team: Gagana B, Spiros Chavlis


Important note: This is bonus material, included from NMA 2020. It has not been substantially revised for 2021, so its notation and conventions differ slightly from those of the other tutorials. We include it here because it provides additional detail on how the Kalman filter works in two dimensions.


Useful references:

  • Roweis & Ghahramani (1998): A Unifying Review of Linear Gaussian Models

  • Bishop (2006): Pattern Recognition and Machine Learning


Acknowledgements:

This tutorial is based in part on code originally created by Caroline Haimerl for Dr. Cristina Savin’s Probabilistic Time Series class at the Center for Data Science, New York University.

Video 1: Introduction

Video available at https://youtu.be/6f_51L3i5aQ

Tutorial Objectives

In the previous tutorial we built intuition for the Kalman filter in one dimension. In this tutorial, we examine the Kalman filter in two dimensions and look more closely at its mathematical foundations.
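As a refresher on the generative model the filter assumes, the sketch below samples a two-dimensional linear dynamical system. It uses the common state-space notation (latent state x_t, observation y_t, dynamics matrix A, observation matrix C, noise covariances Q and R) found in Roweis & Ghahramani; the symbols in this tutorial may differ. The function name `simulate_lds` and the slowly rotating example dynamics are hypothetical choices for illustration only.

```python
import numpy as np

def simulate_lds(A, C, Q, R, x0, T, seed=0):
    """Sample a linear Gaussian state-space model (illustrative sketch).

    x_t = A x_{t-1} + w_t,  w_t ~ N(0, Q)   latent dynamics
    y_t = C x_t + v_t,      v_t ~ N(0, R)   noisy observations
    """
    rng = np.random.default_rng(seed)
    n, m = A.shape[0], C.shape[0]
    x = np.zeros((T, n))
    y = np.zeros((T, m))
    x[0] = x0
    for t in range(T):
        if t > 0:
            # Propagate the latent state and add process noise
            x[t] = A @ x[t - 1] + rng.multivariate_normal(np.zeros(n), Q)
        # Observe the state through C with measurement noise
        y[t] = C @ x[t] + rng.multivariate_normal(np.zeros(m), R)
    return x, y

# Hypothetical 2D example: a slightly damped rotation of the latent state
theta = 0.1
A = 0.99 * np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
C = np.eye(2)
Q = 0.01 * np.eye(2)
R = 0.25 * np.eye(2)
x, y = simulate_lds(A, C, Q, R, x0=np.array([1.0, 0.0]), T=200)
print(x.shape, y.shape)  # (200, 2) (200, 2)
```

Scaling the rotation by 0.99 keeps the eigenvalues of A inside the unit circle, so the latent trajectory spirals slowly toward the origin instead of growing without bound.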

In this tutorial, you will:

  • Review linear dynamical systems

  • Implement the Kalman filter

  • Explore how the Kalman filter can be used to smooth data from an eye-tracking experiment
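To make the filtering objective concrete, here is a minimal NumPy sketch of the Kalman filter's update and predict steps for a 2D state, applied to simulated noisy positions. It is not this tutorial's implementation: the function name `kalman_filter`, the random-walk dynamics (A = I, loosely standing in for gaze position), and all noise levels are assumptions chosen for illustration, and it uses a plain matrix inverse rather than a more numerically careful solve.

```python
import numpy as np

def kalman_filter(y, A, C, Q, R, mu0, Sigma0):
    """Return filtered means/covariances of p(x_t | y_1..y_t).

    Assumed model: x_t = A x_{t-1} + w_t, w_t ~ N(0, Q);
                   y_t = C x_t + v_t,     v_t ~ N(0, R);
                   x_1 ~ N(mu0, Sigma0).
    """
    T, n = y.shape[0], A.shape[0]
    mu = np.zeros((T, n))
    Sigma = np.zeros((T, n, n))
    mu_pred, Sigma_pred = mu0, Sigma0  # prior belief for the first step
    for t in range(T):
        # Update: combine the prediction with observation y[t]
        S = C @ Sigma_pred @ C.T + R             # innovation covariance
        K = Sigma_pred @ C.T @ np.linalg.inv(S)  # Kalman gain
        mu[t] = mu_pred + K @ (y[t] - C @ mu_pred)
        Sigma[t] = (np.eye(n) - K @ C) @ Sigma_pred
        # Predict: push the posterior through the dynamics to step t+1
        mu_pred = A @ mu[t]
        Sigma_pred = A @ Sigma[t] @ A.T + Q
    return mu, Sigma

# Hypothetical data: a 2D random walk observed through heavy noise
rng = np.random.default_rng(1)
A, C = np.eye(2), np.eye(2)
Q, R = 0.01 * np.eye(2), 1.0 * np.eye(2)
T = 500
x = np.cumsum(rng.multivariate_normal(np.zeros(2), Q, size=T), axis=0)
y = x @ C.T + rng.multivariate_normal(np.zeros(2), R, size=T)

mu, Sigma = kalman_filter(y, A, C, Q, R, mu0=np.zeros(2), Sigma0=np.eye(2))
mse_raw = np.mean((y - x) ** 2)  # error of the raw observations
mse_kf = np.mean((mu - x) ** 2)  # error of the filtered estimate
print(f"raw MSE {mse_raw:.3f}, filtered MSE {mse_kf:.3f}")
```

With observation noise much larger than process noise, the filtered estimate should track the latent path far more closely than the raw observations do, so comparing `mse_kf` with `mse_raw` is a quick sanity check.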

import sys
!conda install -c conda-forge ipywidgets --yes
!conda install numpy matplotlib scipy requests --yes






qtwebkit-5.212       | 16.2 MB   | ##########1                           |  27% 








gst-plugins-base-1.1 | 2.2 MB    | ###################################4  |  96% 
libpng-1.6.39        | 304 KB    | #9                                    |   5% 










libxkbcommon-1.0.1   | 590 KB    | #                                     |   3% 











fontconfig-2.14.1    | 281 KB    | ##1                                   |   6% 
matplotlib-base-3.7. | 6.7 MB    | ################################3     |  87% 



qt-main-5.15.2       | 53.7 MB   | ###7                                  |  10% 












toml-0.10.2          | 20 KB     | #############################8        |  81% 






qtwebkit-5.212       | 16.2 MB   | ##############4                       |  39% 













pooch-1.4.0          | 41 KB     | ##############4                       |  39% 














libgfortran-ng-11.2. | 20 KB     | #############################9        |  81% 
brotli-1.0.9         | 18 KB     | ################################1     |  87% 
















icu-58.2             | 10.5 MB   |                                       |   0% 

















libbrotlidec-1.0.9   | 31 KB     | ##################9                   |  51% 


















ca-certificates-2023 | 120 KB    | ####9                                 |  13% 
sqlite-3.41.2        | 1.2 MB    | 4                                     |   1% 



qt-main-5.15.2       | 53.7 MB   | ####9                                 |  13% 




















libbrotlicommon-1.0. | 70 KB     | ########4                             |  23% 






qtwebkit-5.212       | 16.2 MB   | ##################8                   |  51% 
pcre-8.45            | 207 KB    | ##8                                   |   8% 
















icu-58.2             | 10.5 MB   | #############8                        |  37% 






















kiwisolver-1.4.4     | 76 KB     | #######7                              |  21% 
qt-main-5.15.2       | 53.7 MB   | ######3                               |  17% 






qtwebkit-5.212       | 16.2 MB   | #######################1              |  63% 























 ... (more hidden) ...
icu-58.2             | 10.5 MB   | #####################9                |  59% 
qt-main-5.15.2       | 53.7 MB   | #######6                              |  21% 






qtwebkit-5.212       | 16.2 MB   | ###########################6          |  75% 























 ... (more hidden) ...
qt-main-5.15.2       | 53.7 MB   | #########5                            |  26% 
















icu-58.2             | 10.5 MB   | #############################6        |  80% 






qtwebkit-5.212       | 16.2 MB   | ################################      |  87% 
qt-main-5.15.2       | 53.7 MB   | ###########1                          |  30% 






qtwebkit-5.212       | 16.2 MB   | ####################################4 |  98% 
qt-main-5.15.2       | 53.7 MB   | ############7                         |  34% 
qt-main-5.15.2       | 53.7 MB   | ##############2                       |  38% 
qt-main-5.15.2       | 53.7 MB   | ################2                     |  44% 
qt-main-5.15.2       | 53.7 MB   | #################8                    |  48% 
conda-23.5.0         | 1.0 MB    | ##################################### | 100% 





conda-23.5.0         | 1.0 MB    | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | ###################1                  |  52% 
qt-main-5.15.2       | 53.7 MB   | ####################4                 |  55% 
qt-main-5.15.2       | 53.7 MB   | #####################6                |  59% 
qt-main-5.15.2       | 53.7 MB   | ######################9               |  62% 
qt-main-5.15.2       | 53.7 MB   | ########################9             |  67% 
qt-main-5.15.2       | 53.7 MB   | ##########################3           |  71% 
gst-plugins-base-1.1 | 2.2 MB    | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | ###########################6          |  75% 
libpng-1.6.39        | 304 KB    | ##################################### | 100% 









libpng-1.6.39        | 304 KB    | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | ############################7         |  78% 
qt-main-5.15.2       | 53.7 MB   | #############################9        |  81% 
fontconfig-2.14.1    | 281 KB    | ##################################### | 100% 



qt-main-5.15.2       | 53.7 MB   | ###############################       |  84% 











fontconfig-2.14.1    | 281 KB    | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | ################################1     |  87% 










libxkbcommon-1.0.1   | 590 KB    | ##################################### | 100% 
libxkbcommon-1.0.1   | 590 KB    | ##################################### | 100% 




pyqt-5.15.7          | 5.1 MB    | ##################################### | 100% 




pyqt-5.15.7          | 5.1 MB    | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | #################################5    |  91% 
toml-0.10.2          | 20 KB     | ##################################### | 100% 













pooch-1.4.0          | 41 KB     | ##################################### | 100% 













pooch-1.4.0          | 41 KB     | ##################################### | 100% 
libgfortran-ng-11.2. | 20 KB     | ##################################### | 100% 














libgfortran-ng-11.2. | 20 KB     | ##################################### | 100% 



qt-main-5.15.2       | 53.7 MB   | ##################################7   |  94% 
brotli-1.0.9         | 18 KB     | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | ###################################7  |  97% 
libbrotlidec-1.0.9   | 31 KB     | ##################################### | 100% 

















libbrotlidec-1.0.9   | 31 KB     | ##################################### | 100% 
ca-certificates-2023 | 120 KB    | ##################################### | 100% 


















ca-certificates-2023 | 120 KB    | ##################################### | 100% 



qt-main-5.15.2       | 53.7 MB   | ####################################8 | 100% 
libbrotlicommon-1.0. | 70 KB     | ##################################### | 100% 




















libbrotlicommon-1.0. | 70 KB     | ##################################### | 100% 
pcre-8.45            | 207 KB    | ##################################### | 100% 





















pcre-8.45            | 207 KB    | ##################################### | 100% 
sqlite-3.41.2        | 1.2 MB    | ##################################### | 100% 



















sqlite-3.41.2        | 1.2 MB    | ##################################### | 100% 






















kiwisolver-1.4.4     | 76 KB     | ##################################### | 100% 
kiwisolver-1.4.4     | 76 KB     | ##################################### | 100% 
matplotlib-base-3.7. | 6.7 MB    | ##################################### | 100% 
 ... (more hidden) ...
icu-58.2             | 10.5 MB   | ##################################### | 100% 
















icu-58.2             | 10.5 MB   | ##################################### | 100% 
qtwebkit-5.212       | 16.2 MB   | ##################################### | 100% 
qt-main-5.15.2       | 53.7 MB   | ##################################### | 100% 
                      

                                                                                


                                                                                


                                                                                



                                                                                




                                                                                





                                                                                






                                                                                







                                                                                








                                                                                









                                                                                










                                                                                











                                                                                












                                                                                













                                                                                














                                                                                















                                                                                
















                                                                                

















                                                                                


















                                                                                



















                                                                                




















                                                                                





















                                                                                






















                                                                                



























































































































































































































































































































































































































































































































































































Preparing transaction: \ 
| 
/ 
- 
\ 
| 
/ 
- 
done
Verifying transaction: | 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
done
Executing transaction: \ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
\ 
| 
/ 
- 
done
# Install PyKalman (https://pykalman.github.io/)
!pip install pykalman --quiet

# Imports
import numpy as np
import matplotlib.pyplot as plt
import pykalman
from scipy import stats

Figure settings

#@title Figure settings
import ipywidgets as widgets       # interactive display
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

Data retrieval and loading

#@title Data retrieval and loading
import io
import os
import hashlib
import requests

fname = "W2D3_mit_eyetracking_2009.npz"
url = "https://osf.io/jfk8w/download"
expected_md5 = "20c7bc4a6f61f49450997e381cf5e0dd"

if not os.path.isfile(fname):
  try:
    r = requests.get(url)
  except requests.ConnectionError:
    print("!!! Failed to download data !!!")
  else:
    if r.status_code != requests.codes.ok:
      print("!!! Failed to download data !!!")
    elif hashlib.md5(r.content).hexdigest() != expected_md5:
      print("!!! Data download appears corrupted !!!")
    else:
      with open(fname, "wb") as fid:
        fid.write(r.content)

def load_eyetracking_data(data_fname=fname):

  with np.load(data_fname, allow_pickle=True) as dobj:
    data = dict(**dobj)

  images = [plt.imread(io.BytesIO(stim), format='JPG')
            for stim in data['stimuli']]
  subjects = data['subjects']

  return subjects, images

Helper functions

#@title Helper functions
np.set_printoptions(precision=3)


def plot_kalman(state, observation, estimate=None, label='filter', color='r',
                title='LDS', axes=None):
    if axes is None:
        fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 6))
        ax1.plot(state[:, 0], state[:, 1], 'g-', label='true latent')
        ax1.plot(observation[:, 0], observation[:, 1], 'k.', label='data')
    else:
        ax1, ax2 = axes

    if estimate is not None:
        ax1.plot(estimate[:, 0], estimate[:, 1], color=color, label=label)
    ax1.set(title=title, xlabel='X position', ylabel='Y position')
    ax1.legend()

    if estimate is None:
        ax2.plot(state[:, 0], observation[:, 0], '.k', label='dim 1')
        ax2.plot(state[:, 1], observation[:, 1], '.', color='grey', label='dim 2')
        ax2.set(title='correlation', xlabel='latent', ylabel='measured')
    else:
        ax2.plot(state[:, 0], estimate[:, 0], '.', color=color,
                 label='latent dim 1')
        ax2.plot(state[:, 1], estimate[:, 1], 'x', color=color,
                 label='latent dim 2')
        ax2.set(title='correlation',
                xlabel='real latent',
                ylabel='estimated latent')
    ax2.legend()

    return ax1, ax2


def plot_gaze_data(data, img=None, ax=None):
    # overlay gaze on stimulus
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))

    xlim = None
    ylim = None
    if img is not None:
        ax.imshow(img, aspect='auto')
        ylim = (img.shape[0], 0)
        xlim = (0, img.shape[1])

    ax.scatter(data[:, 0], data[:, 1], c='m', s=100, alpha=0.7)
    ax.set(xlim=xlim, ylim=ylim)

    return ax


def plot_kf_state(kf, data, ax):
    mu_0 = np.ones(kf.n_dim_state)
    mu_0[:data.shape[1]] = data[0]
    kf.initial_state_mean = mu_0

    mu, sigma = kf.smooth(data)
    ax.plot(mu[:, 0], mu[:, 1], 'limegreen', linewidth=3, zorder=1)
    ax.scatter(mu[0, 0], mu[0, 1], c='orange', marker='>', s=200, zorder=2)
    ax.scatter(mu[-1, 0], mu[-1, 1], c='orange', marker='s', s=200, zorder=2)

Section 1: Linear Dynamical System (LDS)

Video 2: Linear Dynamical Systems

Video available at https://youtu.be/2SWh639YgEg

Kalman filter definitions:

The latent state \(s_t\) evolves as a stochastic linear dynamical system in discrete time, with a dynamics matrix \(D\):

\[\begin{equation} s_t = Ds_{t-1}+w_t \end{equation}\]

Just as in the HMM, the structure is a Markov chain where the state at time point \(t\) is conditionally independent of previous states given the state at time point \(t-1\).

Sensory measurements \(m_t\) (observations) are noisy linear projections of the latent state:

\[\begin{equation} m_t = Hs_{t}+\eta_t \end{equation}\]

Both states and measurements have Gaussian variability, often called noise: ‘process noise’ \(w_t\) for the states, and ‘measurement’ or ‘observation noise’ \(\eta_t\) for the measurements. The initial state is also Gaussian distributed. These quantities have the following means and covariances:

\[\begin{eqnarray} w_t & \sim & \mathcal{N}(0, Q) \\ \eta_t & \sim & \mathcal{N}(0, R) \\ s_0 & \sim & \mathcal{N}(\mu_0, \Sigma_0) \end{eqnarray}\]

As a consequence, \(s_t\), \(m_t\) and their joint distributions are Gaussian. This makes all of the math analytically tractable using linear algebra, so we can easily compute the marginal and conditional distributions we will use for inferring the current state given the entire history of measurements.
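To make the Gaussian claim concrete for a single time step, here is a small NumPy sketch that is not part of the original tutorial. If \(s \sim \mathcal{N}(\mu, \Sigma)\) and \(m = Hs + \eta\) with \(\eta \sim \mathcal{N}(0, R)\) independent of \(s\), then the stacked vector \([s; m]\) is Gaussian with the block mean and covariance computed by the hypothetical helper `joint_state_obs_moments` below. The numeric values of `mu`, `sigma`, `H` and `R` are arbitrary illustrations, and the Monte Carlo comparison is only a sanity check.

```python
import numpy as np


def joint_state_obs_moments(mu, sigma, H, R):
    """Mean and covariance of the stacked vector [s; m], where s ~ N(mu, sigma)
    and m = H s + eta with eta ~ N(0, R) independent of s."""
    mean = np.concatenate([mu, H @ mu])
    cov = np.block([
        [sigma, sigma @ H.T],
        [H @ sigma, H @ sigma @ H.T + R],
    ])
    return mean, cov


# Arbitrary illustrative parameters (not taken from the tutorial)
mu = np.array([1.0, -2.0])
sigma = np.array([[1.0, 0.3], [0.3, 0.5]])
H = np.array([[1.0, 0.0], [0.5, 1.0]])
R = 0.2 * np.eye(2)

mean, cov = joint_state_obs_moments(mu, sigma, H, R)

# Monte Carlo check: draw s and eta, form m, and compare empirical moments
rng = np.random.default_rng(0)
n = 200_000
s = rng.multivariate_normal(mu, sigma, size=n)
eta = rng.multivariate_normal(np.zeros(2), R, size=n)
m = s @ H.T + eta                      # row-wise version of m = H s + eta
samples = np.hstack([s, m])

mean_err = np.abs(samples.mean(axis=0) - mean).max()
cov_err = np.abs(np.cov(samples, rowvar=False) - cov).max()
print(f"max mean error {mean_err:.4f}, max covariance error {cov_err:.4f}")
```

The bottom-right block \(H\Sigma H^T + R\) is the same expression that reappears later as the covariance of the predicted measurement in the Kalman filter.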

Please note: we are trying to create uniform notation across tutorials. In some videos created in 2020, measurements \(m_t\) were denoted \(y_t\), and the dynamics matrix \(D\) was denoted \(F\). We apologize for any confusion!

Section 1.1: Sampling from a latent linear dynamical system

The first thing we will investigate is how to generate timecourse samples from a linear dynamical system given its parameters. We will start by defining the following system:

# task dimensions
n_dim_state = 2
n_dim_obs = 2

# initialize model parameters
params = {
  'D': 0.9 * np.eye(n_dim_state),  # state transition matrix
  'Q': np.eye(n_dim_state),  # state noise covariance
  'H': np.eye(n_dim_state),  # observation matrix
  'R': 1.0 * np.eye(n_dim_obs),  # observation noise covariance
  'mu_0': np.zeros(n_dim_state),  # initial state mean
  'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance
}

Coding note: We used a parameter dictionary params above. As the number of parameters our functions need grows, it can help to bundle them into a single data structure like this instead of passing each one as a separate argument. The trade-off is that we have to know what is in the data structure to use its values, rather than reading them off the function signature directly.
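One way to soften that trade-off is to validate the dictionary once, up front. The sketch below uses a hypothetical helper, `validate_lds_params`, which is not part of the tutorial or of any library. It checks that the six keys used in this tutorial are present with the expected shapes, and returns the state and observation dimensions.

```python
import numpy as np

REQUIRED_KEYS = ('D', 'Q', 'H', 'R', 'mu_0', 'sigma_0')


def validate_lds_params(params):
    """Check an LDS parameter dict for the keys and shapes used in this tutorial.

    Returns (n_dim_state, n_dim_obs); raises KeyError or ValueError otherwise.
    """
    missing = [key for key in REQUIRED_KEYS if key not in params]
    if missing:
        raise KeyError(f"missing parameters: {missing}")

    n_state = params['D'].shape[0]
    n_obs = params['H'].shape[0]
    expected = {
        'D': (n_state, n_state),   # state transition matrix
        'Q': (n_state, n_state),   # state noise covariance
        'H': (n_obs, n_state),     # observation matrix
        'R': (n_obs, n_obs),       # observation noise covariance
        'mu_0': (n_state,),        # initial state mean
        'sigma_0': (n_state, n_state),  # initial state covariance
    }
    for key, shape in expected.items():
        if np.shape(params[key]) != shape:
            raise ValueError(f"{key} has shape {np.shape(params[key])}, expected {shape}")
    return n_state, n_obs
```

Calling `validate_lds_params(params)` on the dictionary above should return `(2, 2)`, while, for example, a 3x3 `Q` raises a `ValueError` that names the offending entry.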

Exercise 1: Sampling from a linear dynamical system

In this exercise you will implement the dynamics functions of a linear dynamical system to sample both a latent space trajectory (given parameters set above) and noisy measurements.


def sample_lds(n_timesteps, params, seed=0):
  """ Generate samples from a Linear Dynamical System specified by the provided
  parameters.

  Args:
  n_timesteps (int): the number of time steps to simulate
  params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)
  seed (int): a random seed to use for reproducibility checks

  Returns:
  ndarray, ndarray: the generated state and observation data
  """
  n_dim_state = params['D'].shape[0]
  n_dim_obs = params['H'].shape[0]

  # set seed
  np.random.seed(seed)

  # precompute random samples from the provided covariance matrices
  # mean defaults to 0
  mi = stats.multivariate_normal(cov=params['Q']).rvs(n_timesteps)
  eta = stats.multivariate_normal(cov=params['R']).rvs(n_timesteps)

  # initialize state and observation arrays
  state = np.zeros((n_timesteps, n_dim_state))
  obs = np.zeros((n_timesteps, n_dim_obs))

  ###################################################################
  ## TODO for students: compute the next state and observation values
  # Fill out function and remove
  raise NotImplementedError("Student exercise: compute the next state and observation values")
  ###################################################################

  # simulate the system
  for t in range(n_timesteps):
    # write the expressions for computing state values given the time step
    if t == 0:
      state[t] = ...
    else:
      state[t] = ...

    # write the expression for computing the observation
    obs[t] = ...

  return state, obs


# Uncomment below to test your function
# state, obs = sample_lds(100, params)
# print('sample at t=3 ', state[3])
# plot_kalman(state, obs, title='sample')

# to_remove solution
def sample_lds(n_timesteps, params, seed=0):
  """ Generate samples from a Linear Dynamical System specified by the provided
  parameters.

  Args:
  n_timesteps (int): the number of time steps to simulate
  params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)
  seed (int): a random seed to use for reproducibility checks

  Returns:
  ndarray, ndarray: the generated state and observation data
  """
  n_dim_state = params['D'].shape[0]
  n_dim_obs = params['H'].shape[0]

  # set seed
  np.random.seed(seed)

  # precompute random samples from the provided covariance matrices
  # mean defaults to 0
  mi = stats.multivariate_normal(cov=params['Q']).rvs(n_timesteps)
  eta = stats.multivariate_normal(cov=params['R']).rvs(n_timesteps)

  # initialize state and observation arrays
  state = np.zeros((n_timesteps, n_dim_state))
  obs = np.zeros((n_timesteps, n_dim_obs))

  # simulate the system
  for t in range(n_timesteps):
    # write the expressions for computing state values given the time step
    if t == 0:
      state[t] = stats.multivariate_normal(mean=params['mu_0'],
                                           cov=params['sigma_0']).rvs(1)
    else:
      state[t] = params['D'] @ state[t-1] + mi[t]

    # write the expression for computing the observation
    obs[t] = params['H'] @ state[t] + eta[t]

  return state, obs


state, obs = sample_lds(100, params)
print('sample at t=3 ', state[3])
with plt.xkcd():
  plot_kalman(state, obs, title='sample')
sample at t=3  [3.286 0.527]
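[Figure: left panel "sample" shows the true latent and the data (X position vs. Y position); right panel "correlation" plots latent vs. measured values for each dimension.]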

Interactive Demo: Adjusting System Dynamics

To test your understanding of the parameters of a linear dynamical system, think about what you would expect if you made the following changes:

  1. Reduce observation noise \(R\)

  2. Increase the temporal dynamics \(D\) (the value on the diagonal of \(D\), applied to both dimensions)

Use the interactive widget below to vary the values of \(R\) and \(D\).
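For intuition about the second change, note that with \(D = d\,I\) and \(Q = I\), as in the parameters above, each coordinate of the latent state is an independent first-order autoregressive process. For \(|d| < 1\) its long-run variance is \(q/(1-d^2)\) and its lag-1 correlation is \(d\). Larger \(d\) therefore gives smoother trajectories that wander further from zero; at \(d = 1\) the process becomes a random walk with no stationary variance. Changing \(R\) leaves the latent trajectory unchanged and only alters how tightly the measurements scatter around it. The simulation below is a minimal scalar sketch that checks these two formulas; `simulate_ar1` is a hypothetical helper.

```python
import numpy as np


def simulate_ar1(d, q, n_steps, rng):
    """Simulate one coordinate of s_t = d * s_{t-1} + w_t, with w_t ~ N(0, q)."""
    s = np.zeros(n_steps)
    w = rng.normal(0.0, np.sqrt(q), size=n_steps)
    for t in range(1, n_steps):
        s[t] = d * s[t - 1] + w[t]
    return s


rng = np.random.default_rng(1)
q = 1.0
results = {}
for d in (0.0, 0.5, 0.9):
    s = simulate_ar1(d, q, 100_000, rng)
    theory = q / (1 - d**2)                     # stationary variance for |d| < 1
    lag1 = np.corrcoef(s[:-1], s[1:])[0, 1]     # should be close to d
    results[d] = (s.var(), theory, lag1)
    print(f"d={d}: variance {s.var():.2f} (theory {theory:.2f}), lag-1 corr {lag1:.2f}")
```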

Make sure you execute this cell to enable the widget!

#@title

#@markdown Make sure you execute this cell to enable the widget!

@widgets.interact(R=widgets.FloatLogSlider(1., min=-2, max=2),
                  D=widgets.FloatSlider(0.9, min=0.0, max=1.0, step=.01))
def explore_dynamics(R=0.1, D=0.5):
    params = {
    'D': D * np.eye(n_dim_state),  # state transition matrix
    'Q': np.eye(n_dim_state),  # state noise covariance
    'H': np.eye(n_dim_state),  # observation matrix
    'R': R * np.eye(n_dim_obs),  # observation noise covariance
    'mu_0': np.zeros(n_dim_state),  # initial state mean
    'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance
    }

    state, obs = sample_lds(100, params)
    plot_kalman(state, obs, title='sample')

Section 2: Kalman Filtering

Video 3: Kalman Filtering

Video available at https://youtu.be/VboZOV9QMOI

We want to infer the latent state variable \(s_t\) given the measured (observed) variable \(m_t\).

\[\begin{equation} P(s_t|m_1, ..., m_t, m_{t+1}, ..., m_T)\sim \mathcal{N}(\hat{\mu}_t, \hat{\Sigma_t}) \end{equation}\]

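First, we obtain estimates of the latent state by running the filter forward in time, for \(t=0, \dots, T\). At each step the filter uses only the measurements up to time \(t\), starting from a prediction of the state: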

\[\begin{equation} s_t^{\rm pred}\sim \mathcal{N}(\hat{\mu}_t^{\rm pred},\hat{\Sigma}_t^{\rm pred})\end{equation}\]

Where \(\hat{\mu}_t^{\rm pred}\) and \(\hat{\Sigma}_t^{\rm pred}\) are derived as follows:

\[\begin{eqnarray} \hat{\mu}_1^{\rm pred} & = & D\hat{\mu}_{0} \\ \hat{\mu}_t^{\rm pred} & = & D\hat{\mu}_{t-1} \end{eqnarray}\]

This is the prediction for \(s_t\) obtained simply by taking the expected value of \(s_{t-1}\) and projecting it forward one step using the transition matrix \(D\). We do the same for the covariance, taking into account the noise covariance \(Q\) and the fact that scaling a variable by \(D\) scales its covariance \(\Sigma\) as \(D\Sigma D^T\):

\[\begin{eqnarray} \hat{\Sigma}_1^{\rm pred} & = & D\hat{\Sigma}_{0}D^T+Q \\ \hat{\Sigma}_t^{\rm pred} & = & D\hat{\Sigma}_{t-1}D^T+Q \end{eqnarray}\]
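The two prediction equations can be written as one small function. The sketch below uses a hypothetical `kalman_predict` that mirrors the `else` branch of the filter code later in this tutorial. It pushes an arbitrary Gaussian through the oscillatory \(D\) used later, then checks by sampling that the propagated samples have mean \(D\mu\) and covariance \(D\Sigma D^T + Q\).

```python
import numpy as np


def kalman_predict(mu, sigma, D, Q):
    """Prediction step: map N(mu, sigma) through s_next = D s + w, w ~ N(0, Q)."""
    return D @ mu, D @ sigma @ D.T + Q


# Oscillatory D from later in this tutorial; mu and sigma are arbitrary
D = np.array([[1.0, 1.0], [-(2 * np.pi / 20.0) ** 2, 0.9]])
Q = np.eye(2)
mu = np.array([0.5, -0.5])
sigma = np.array([[0.4, 0.1], [0.1, 0.3]])

mu_pred, sigma_pred = kalman_predict(mu, sigma, D, Q)

# Monte Carlo check: propagate samples of s one step and compare moments
rng = np.random.default_rng(2)
n = 200_000
s = rng.multivariate_normal(mu, sigma, size=n)
s_next = s @ D.T + rng.multivariate_normal(np.zeros(2), Q, size=n)
mean_err = np.abs(s_next.mean(axis=0) - mu_pred).max()
cov_err = np.abs(np.cov(s_next, rowvar=False) - sigma_pred).max()
print(f"max mean error {mean_err:.4f}, max covariance error {cov_err:.4f}")
```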

We then use a Bayesian update based on the newest measurement \(m_t\) to obtain \(\hat{\mu}_t^{\rm filter}\) and \(\hat{\Sigma}_t^{\rm filter}\).

First, project the prediction into observation space:

\[\begin{equation} m_t^{\rm pred}\sim \mathcal{N}(H\hat{\mu}_t^{\rm pred}, H\hat{\Sigma}_t^{\rm pred}H^T+R) \end{equation}\]

Then update the prediction using the actual measurement:

\[\begin{eqnarray} s_t^{\rm filter} & \sim & \mathcal{N}(\hat{\mu}_t^{\rm filter}, \hat{\Sigma}_t^{\rm filter}) \\ \hat{\mu}_t^{\rm filter} & = & \hat{\mu}_t^{\rm pred}+K_t(m_t-H\hat{\mu}_t^{\rm pred}) \\ \hat{\Sigma}_t^{\rm filter} & = & (I-K_tH)\hat{\Sigma}_t^{\rm pred} \end{eqnarray}\]

Kalman gain matrix:

\[\begin{equation} K_t=\hat{\Sigma}_t^{\rm pred}H^T(H\hat{\Sigma}_t^{\rm pred}H^T+R)^{-1} \end{equation}\]
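To see the update in numbers, the sketch below implements the gain, mean and covariance equations directly. It uses a hypothetical `kalman_update`, equivalent to the body of the filter loop later in this tutorial. In one dimension with \(H = 1\), a prediction \(\mathcal{N}(0, 1)\) and a measurement \(m = 2\) with \(R = 1\) give \(K = 0.5\), a filtered mean of 1 and a filtered variance of 0.5. The estimate lands halfway because prediction and measurement are equally reliable. The second part checks, for an arbitrary 2-D example, that the gain form agrees with the precision-weighted (information) form of the same Bayesian update: \((\hat{\Sigma}_t^{\rm filter})^{-1} = (\hat{\Sigma}_t^{\rm pred})^{-1} + H^T R^{-1} H\).

```python
import numpy as np


def kalman_update(mu_pred, sigma_pred, m, H, R):
    """Measurement update: Kalman gain, filtered mean and filtered covariance."""
    S = H @ sigma_pred @ H.T + R                      # covariance of the predicted measurement
    K = sigma_pred @ H.T @ np.linalg.inv(S)           # Kalman gain
    mu_filt = mu_pred + K @ (m - H @ mu_pred)         # shift by the weighted prediction error
    sigma_filt = (np.eye(len(mu_pred)) - K @ H) @ sigma_pred
    return mu_filt, sigma_filt, K


# 1-D case written with 1x1 matrices: gain 0.5, mean 1.0, variance 0.5
mu_1d, sigma_1d, K_1d = kalman_update(np.array([0.0]), np.array([[1.0]]),
                                      np.array([2.0]), np.array([[1.0]]),
                                      np.array([[1.0]]))

# 2-D check against the information (precision-weighted) form; values are arbitrary
mu_pred = np.array([1.0, 0.0])
sigma_pred = np.array([[2.0, 0.5], [0.5, 1.0]])
H = np.array([[1.0, 0.0], [0.3, 1.0]])
R = np.diag([0.5, 2.0])
m = np.array([1.5, -0.5])

mu_gain, sigma_gain, _ = kalman_update(mu_pred, sigma_pred, m, H, R)

R_inv = np.linalg.inv(R)
prior_precision = np.linalg.inv(sigma_pred)
sigma_info = np.linalg.inv(prior_precision + H.T @ R_inv @ H)
mu_info = sigma_info @ (prior_precision @ mu_pred + H.T @ R_inv @ m)
print(np.allclose(mu_gain, mu_info), np.allclose(sigma_gain, sigma_info))
```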

We project the latent prediction into observation space and compute a correction proportional to the prediction error \(m_t-H\hat{\mu}_t^{\rm pred}\) between data and prediction (for \(t>1\) this is \(m_t-HD\hat{\mu}_{t-1}\)). The coefficient of this correction is the Kalman gain matrix \(K_t\).

Interpretations

If the measurement noise \(R\) is small and the dynamics are fast (the state changes substantially from one time step to the next), the estimate depends mostly on the current measurement. If the measurement noise is large, the Kalman filter also draws on past observations, combining them as long as the underlying state is at least somewhat predictable.
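A scalar sketch makes this trade-off visible. It assumes \(H = 1\) and scalar \(D = d\), \(Q = q\), \(R = r\). Under those assumptions, iterating the prediction and update covariance equations from above converges to a steady-state gain \(K\). The hypothetical helper `steady_state_gain` shows \(K\) near 1 (trust the current measurement) when \(r\) is small relative to \(q\). When \(r\) is large, \(K\) falls near 0 and the filter relies on the prediction built from past data.

```python
def steady_state_gain(d, q, r, n_iter=1000):
    """Iterate the scalar Kalman covariance recursions (H = 1) and return the
    gain they settle on."""
    sigma = 1.0                              # arbitrary starting variance
    for _ in range(n_iter):
        sigma_pred = d * sigma * d + q       # prediction step
        k = sigma_pred / (sigma_pred + r)    # Kalman gain
        sigma = (1 - k) * sigma_pred         # update step
    return k


for r in (0.01, 1.0, 100.0):
    print(f"d=0.9, q=1, r={r}: K = {steady_state_gain(0.9, 1.0, r):.3f}")
for q in (0.01, 1.0, 100.0):
    print(f"d=0.9, q={q}, r=1: K = {steady_state_gain(0.9, q, 1.0):.3f}")
```

Only the ratio between process and measurement noise matters here: raising \(q\) has the same qualitative effect as lowering \(r\).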

In order to explore the impact of filtering, we will use the following noisy oscillatory system:

# task dimensions
n_dim_state = 2
n_dim_obs = 2

T = 100

# initialize model parameters
params = {
  'D': np.array([[1., 1.], [-(2*np.pi/20.)**2., .9]]),  # state transition matrix
  'Q': np.eye(n_dim_state),                             # state noise covariance
  'H': np.eye(n_dim_state),                             # observation matrix
  'R': 100.0 * np.eye(n_dim_obs),                       # observation noise covariance
  'mu_0': np.zeros(n_dim_state),                        # initial state mean
  'sigma_0': 0.1 * np.eye(n_dim_state),                 # initial state noise covariance
}

state, obs = sample_lds(T, params)
plot_kalman(state, obs, title='sample')
[Figure: left panel "sample" shows the true latent and the data of the oscillatory system (X position vs. Y position); right panel "correlation" plots latent vs. measured values.]

Exercise 2: Implement Kalman filtering

In this exercise you will implement the Kalman filter (forward) process. Your focus will be on writing the expressions for the Kalman gain, filter mean, and filter covariance at each time step (refer to the equations above).


def kalman_filter(data, params):
  """ Perform Kalman filtering (forward pass) on the data given the provided
  system parameters.

  Args:
    data (ndarray): a sequence of observations of shape (n_timesteps, n_dim_obs)
    params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)

  Returns:
    ndarray, ndarray: the filtered system means and noise covariance values
  """
  # pulled out of the params dict for convenience
  D = params['D']
  Q = params['Q']
  H = params['H']
  R = params['R']

  n_dim_state = D.shape[0]
  n_dim_obs = H.shape[0]
  I = np.eye(n_dim_state)  # identity matrix

  # state tracking arrays
  mu = np.zeros((len(data), n_dim_state))
  sigma = np.zeros((len(data), n_dim_state, n_dim_state))

  # filter the data
  for t, y in enumerate(data):
    if t == 0:
      mu_pred = params['mu_0']
      sigma_pred = params['sigma_0']
    else:
      mu_pred = D @ mu[t-1]
      sigma_pred = D @ sigma[t-1] @ D.T + Q

    ###########################################################################
    ## TODO for students: compute the filtered state mean and covariance values
    # Fill out function and remove
    raise NotImplementedError("Student exercise: compute the filtered state mean and covariance values")
    ###########################################################################
    # write the expression for computing the Kalman gain
    K = ...
    # write the expression for computing the filtered state mean
    mu[t] = ...
    # write the expression for computing the filtered state noise covariance
    sigma[t] = ...

  return mu, sigma


# Uncomment below to test your function
# filtered_state_means, filtered_state_covariances = kalman_filter(obs, params)
# plot_kalman(state, obs, filtered_state_means, title="my kf-filter",
#             color='r', label='my kf-filter')

# to_remove solution
def kalman_filter(data, params):
  """ Perform Kalman filtering (forward pass) on the data given the provided
  system parameters.

  Args:
    data (ndarray): a sequence of observations of shape (n_timesteps, n_dim_obs)
    params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)

  Returns:
    ndarray, ndarray: the filtered system means and noise covariance values
  """
  # pulled out of the params dict for convenience
  D = params['D']
  Q = params['Q']
  H = params['H']
  R = params['R']

  n_dim_state = D.shape[0]
  n_dim_obs = H.shape[0]
  I = np.eye(n_dim_state)  # identity matrix

  # state tracking arrays
  mu = np.zeros((len(data), n_dim_state))
  sigma = np.zeros((len(data), n_dim_state, n_dim_state))

  # filter the data
  for t, y in enumerate(data):
    if t == 0:
      mu_pred = params['mu_0']
      sigma_pred = params['sigma_0']
    else:
      mu_pred = D @ mu[t-1]
      sigma_pred = D @ sigma[t-1] @ D.T + Q

    # write the expression for computing the Kalman gain
    K = sigma_pred @ H.T @ np.linalg.inv(H @ sigma_pred @ H.T + R)
    # write the expression for computing the filtered state mean
    mu[t] = mu_pred + K @ (y - H @ mu_pred)
    # write the expression for computing the filtered state noise covariance
    sigma[t] = (I - K @ H) @ sigma_pred

  return mu, sigma


filtered_state_means, filtered_state_covariances = kalman_filter(obs, params)
with plt.xkcd():
  plot_kalman(state, obs, filtered_state_means, title="my kf-filter",
              color='r', label='my kf-filter')
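[Figure: left panel "my kf-filter" overlays the filtered estimate on the true latent and the data (X position vs. Y position); right panel "correlation" plots real vs. estimated latent.]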
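Beyond the plots, a quick quantitative check is to compare two errors against the true latent: that of the filtered means and that of the raw measurements. The sketch below is self-contained. `simulate_lds` and `filter_lds` are hypothetical compact restatements of `sample_lds` and `kalman_filter` above. Noise is drawn with NumPy's `default_rng` rather than `scipy.stats`, so the samples will not match the tutorial's. The sketch uses the same oscillatory parameters but a longer sequence of 2,000 steps so the comparison is less noisy. With \(R = 100\,I\), the filter should give a much smaller error than the measurements alone.

```python
import numpy as np


def simulate_lds(D, Q, H, R, mu_0, sigma_0, n_steps, rng):
    """Sample latent states and measurements from the LDS defined above."""
    n_state, n_obs = D.shape[0], H.shape[0]
    w = rng.multivariate_normal(np.zeros(n_state), Q, size=n_steps)    # process noise
    eta = rng.multivariate_normal(np.zeros(n_obs), R, size=n_steps)    # measurement noise
    s = np.zeros((n_steps, n_state))
    s[0] = rng.multivariate_normal(mu_0, sigma_0)
    for t in range(1, n_steps):
        s[t] = D @ s[t - 1] + w[t]
    m = s @ H.T + eta
    return s, m


def filter_lds(m, D, Q, H, R, mu_0, sigma_0):
    """Kalman filter forward pass, same recursion as kalman_filter above."""
    n_state = D.shape[0]
    mu = np.zeros((len(m), n_state))
    sigma = np.zeros((len(m), n_state, n_state))
    for t, y in enumerate(m):
        if t == 0:
            mu_pred, sigma_pred = mu_0, sigma_0
        else:
            mu_pred = D @ mu[t - 1]
            sigma_pred = D @ sigma[t - 1] @ D.T + Q
        K = sigma_pred @ H.T @ np.linalg.inv(H @ sigma_pred @ H.T + R)
        mu[t] = mu_pred + K @ (y - H @ mu_pred)
        sigma[t] = (np.eye(n_state) - K @ H) @ sigma_pred
    return mu, sigma


D = np.array([[1.0, 1.0], [-(2 * np.pi / 20.0) ** 2, 0.9]])
Q, H, R = np.eye(2), np.eye(2), 100.0 * np.eye(2)
mu_0, sigma_0 = np.zeros(2), 0.1 * np.eye(2)

rng = np.random.default_rng(0)
s_true, m_obs = simulate_lds(D, Q, H, R, mu_0, sigma_0, 2000, rng)
mu_filt, _ = filter_lds(m_obs, D, Q, H, R, mu_0, sigma_0)

# With H = I, measurements and latent states live in the same space
mse_obs = np.mean(np.sum((m_obs - s_true) ** 2, axis=1))
mse_filt = np.mean(np.sum((mu_filt - s_true) ** 2, axis=1))
print(f"MSE of raw measurements: {mse_obs:.1f}, MSE of filtered means: {mse_filt:.1f}")
```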

Section 3: Fitting Eye Gaze Data

Video 4: Fitting Eye Gaze Data

Video available at https://youtu.be/M7OuXmVWHGI

Tracking eye gaze is used in both experimental and user interface applications. Getting an accurate estimation of where someone is looking on a screen in pixel coordinates can be challenging, however, due to the various sources of noise inherent in obtaining these measurements. A main source of noise is the general accuracy of the eye tracker device itself and how well it maintains calibration over time. Changes in ambient light or subject position can further reduce accuracy of the sensor. Eye blinks introduce a different form of noise as interruptions in the data stream which also need to be addressed.
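One simple way to handle blinks within the Kalman framework is to run only the prediction step whenever a measurement is missing. The mean is then carried forward by \(D\), and the covariance grows by \(Q\) until data return. The sketch below assumes that samples lost during a blink are stored as `NaN`. `kalman_filter_with_gaps` is a hypothetical variant of the `kalman_filter` exercise function, not part of the tutorial or of `pykalman`, and the toy gaze trace is synthetic.

```python
import numpy as np


def kalman_filter_with_gaps(data, D, Q, H, R, mu_0, sigma_0):
    """Kalman filter that treats any row of `data` containing NaN as a missing
    measurement (for example, a blink) and skips the update step there."""
    n_state = D.shape[0]
    mu = np.zeros((len(data), n_state))
    sigma = np.zeros((len(data), n_state, n_state))
    for t, y in enumerate(data):
        if t == 0:
            mu_pred, sigma_pred = mu_0, sigma_0
        else:
            mu_pred = D @ mu[t - 1]
            sigma_pred = D @ sigma[t - 1] @ D.T + Q
        if np.any(np.isnan(y)):
            # No measurement: keep the prediction as the estimate
            mu[t], sigma[t] = mu_pred, sigma_pred
        else:
            K = sigma_pred @ H.T @ np.linalg.inv(H @ sigma_pred @ H.T + R)
            mu[t] = mu_pred + K @ (y - H @ mu_pred)
            sigma[t] = (np.eye(n_state) - K @ H) @ sigma_pred
    return mu, sigma


# Synthetic 2-D gaze-like trace with a simulated blink at steps 10-14
rng = np.random.default_rng(3)
gaze = np.cumsum(rng.normal(0.0, 1.0, size=(30, 2)), axis=0) + rng.normal(0.0, 2.0, size=(30, 2))
gaze[10:15] = np.nan
mu, sigma = kalman_filter_with_gaps(gaze, np.eye(2), np.eye(2), np.eye(2), 4.0 * np.eye(2),
                                    np.zeros(2), 10.0 * np.eye(2))
blink_var = [np.trace(sigma[t]) for t in range(10, 15)]
print(np.round(blink_var, 2))   # total variance grows at every step of the blink
```

With \(D = I\) and \(Q = I\) in this toy setup, the estimate stays at its last value during the blink while the trace of the covariance increases by 2 per missing sample.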

Fortunately, the Kalman filter we just learned about is a candidate solution for handling noisy eye gaze data. Let’s look at how we can apply these methods to a small subset of data taken from the MIT Eyetracking Database [Judd et al. 2009]. This data was collected as part of an effort to model visual saliency: given an image, can we predict where a person is most likely to look?

# load eyetracking data
subjects, images = load_eyetracking_data()

Interactive Demo: Tracking Eye Gaze

We have three stimulus images and five different subjects’ gaze data. Each subject fixated in the center of the screen before the image appeared, then had a few seconds to freely look around. You can use the widget below to see how different subjects visually scanned the presented image. A subject ID of -1 will show the stimulus images without any overlayed gaze trace.

Note that the images below are rescaled for display purposes; during the task itself they were shown in their original aspect ratio.

Make sure you execute this cell to enable the widget!

#@title

#@markdown Make sure you execute this cell to enable the widget!

@widgets.interact(subject_id=widgets.IntSlider(-1, min=-1, max=4),
                  image_id=widgets.IntSlider(0, min=0, max=2))
def plot_subject_trace(subject_id=-1, image_id=0):
  if subject_id == -1:
    subject = np.zeros((3, 0, 2))
  else:
    subject = subjects[subject_id]
  data = subject[image_id]
  img = images[image_id]

  fig, ax = plt.subplots()
  ax.imshow(img, aspect='auto')
  ax.scatter(data[:, 0], data[:, 1], c='m', s=100, alpha=0.7)
  ax.set(xlim=(0, img.shape[1]), ylim=(img.shape[0], 0))

Section 3.1: Fitting data with pykalman

Now that we have data, we’d like to use Kalman filtering to give us a better estimate of the true gaze. Up until this point we’ve known the parameters of our LDS, but here we need to estimate them from data directly. We will use the pykalman package to handle this estimation using the EM algorithm, a useful and influential learning algorithm described briefly in the bonus material.

Before exploring fitting models with pykalman it’s worth pointing out some naming conventions used by the library:

\[\begin{aligned} D &: \texttt{transition\_matrices} & Q &: \texttt{transition\_covariance} \\ H &: \texttt{observation\_matrices} & R &: \texttt{observation\_covariance} \\ \mu_0 &: \texttt{initial\_state\_mean} & \Sigma_0 &: \texttt{initial\_state\_covariance} \end{aligned}\]
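The correspondence above can be written down as a small lookup table. The helper below, `to_pykalman_kwargs`, is hypothetical (it is not part of pykalman or of this tutorial); it assumes the tutorial's `params` dict uses the keys `D`, `Q`, `H`, `R`, `mu_0`, `sigma_0`, as listed in the `kalman_smooth` docstring, and simply renames them to the pykalman parameter names in the table.

```python
def to_pykalman_kwargs(params):
    """Hypothetical helper: rename the tutorial's parameter keys to the
    pykalman.KalmanFilter keyword names listed in the table above."""
    name_map = {
        'D': 'transition_matrices',
        'Q': 'transition_covariance',
        'H': 'observation_matrices',
        'R': 'observation_covariance',
        'mu_0': 'initial_state_mean',
        'sigma_0': 'initial_state_covariance',
    }
    # Keys that have no pykalman counterpart are silently dropped
    return {name_map[key]: value for key, value in params.items()
            if key in name_map}


# Usage sketch (assumes pykalman is installed and `params` is the tutorial dict):
# kf_known = pykalman.KalmanFilter(**to_pykalman_kwargs(params))
```

Building a filter this way fixes all parameters at known values, which is the opposite situation from the one below, where EM has to estimate them.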

The first thing we need to do is provide a guess at the dimensionality of the latent state. Let’s start by assuming the dynamics line up directly with the observation data (pixel x,y-coordinates), so that the state dimension is 2.

We also need to decide which parameters we want the EM algorithm to fit. In this case, we will let the EM algorithm discover the dynamics and observation parameters, i.e., the \(D\), \(Q\), \(H\), and \(R\) matrices.

We set up our pykalman KalmanFilter object with these settings using the code below.

# set up our KalmanFilter object and tell it which parameters we want to
# estimate
np.random.seed(1)

n_dim_obs = 2
n_dim_state = 2

kf = pykalman.KalmanFilter(
  n_dim_state=n_dim_state,
  n_dim_obs=n_dim_obs,
  em_vars=['transition_matrices', 'transition_covariance',
           'observation_matrices', 'observation_covariance']
)

Because we know from the reported experimental design that subjects fixated in the center of the screen right before the image appears, we can set the initial starting state estimate \(\mu_0\) as being the center pixel of the stimulus image (the first data point in this sample dataset) with a correspondingly low initial noise covariance \(\Sigma_0\). Once we have everything set, it’s time to fit some data.

# Choose a subject and stimulus image
subject_id = 1
image_id = 2
data = subjects[subject_id][image_id]

# Provide the initial states
kf.initial_state_mean = data[0]
kf.initial_state_covariance = 0.1*np.eye(n_dim_state)

# Estimate the parameters from data using the EM algorithm
kf.em(data)

print(f'D =\n{kf.transition_matrices}')
print(f'Q =\n{kf.transition_covariance}')
print(f'H =\n{kf.observation_matrices}')
print(f'R =\n{kf.observation_covariance}')
D =
[[ 1.004 -0.01 ]
 [ 0.005  0.989]]
Q =
[[278.016 219.292]
 [219.292 389.774]]
H =
[[ 0.999  0.003]
 [-0.004  1.01 ]]
R =
[[26.026 19.596]
 [19.596 26.745]]

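We see that the EM algorithm has found fits for the various dynamics parameters. Note that both the state transition matrix \(D\) and the observation matrix \(H\) are close to the identity matrix, so each coordinate's next value is approximately its current value. Apart from the correlated noise captured by the off-diagonal terms of \(Q\) and \(R\), the x- and y-coordinate dynamics are therefore nearly independent of each other and are driven primarily by the noise covariances.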
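One way to make "close to the identity" and "driven by the noise" concrete is to measure how far \(D\) and \(H\) are from \(I\), and how strongly the x and y noise terms are correlated. The sketch below copies the rounded values printed above; `noise_correlation` is a hypothetical helper that turns a 2x2 covariance into a correlation coefficient.

```python
import numpy as np

# Fitted values copied (rounded) from the EM output printed above
D = np.array([[1.004, -0.01], [0.005, 0.989]])
H = np.array([[0.999, 0.003], [-0.004, 1.01]])
Q = np.array([[278.016, 219.292], [219.292, 389.774]])
R = np.array([[26.026, 19.596], [19.596, 26.745]])


def noise_correlation(cov):
    """Correlation between the x and y noise terms of a 2x2 covariance."""
    return cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])


# Largest absolute deviation of each matrix from the identity
print(f"max |D - I| = {np.abs(D - np.eye(2)).max():.3f}")  # 0.011
print(f"max |H - I| = {np.abs(H - np.eye(2)).max():.3f}")  # 0.010
# Correlation of the state noise and of the observation noise
print(f"corr(Q) = {noise_correlation(Q):.2f}")  # 0.67
print(f"corr(R) = {noise_correlation(R):.2f}")  # 0.74
```

The transition and observation matrices deviate from \(I\) by about 1%, while both noise covariances have substantial x-y correlation, so most of the coupling between the two coordinates lives in the noise rather than in the dynamics.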

We can now use this model to smooth the observed data from the subject. Besides the image used for fitting, we can also check how well this model works on the gaze recorded by the same subject on the other images, or even on data from different subjects.

Below are the three stimulus images overlaid with the recorded gaze in magenta and the smoothed state from the filter in green, with gaze start (orange triangle) and gaze end (orange square) markers.

Make sure you execute this cell to enable the widget!

#@title

#@markdown Make sure you execute this cell to enable the widget!

@widgets.interact(subject_id=widgets.IntSlider(1, min=0, max=4))
def plot_smoothed_traces(subject_id=0):
  subject = subjects[subject_id]
  fig, axes = plt.subplots(ncols=3, figsize=(18, 4))
  for data, img, ax in zip(subject, images, axes):
    ax = plot_gaze_data(data, img=img, ax=ax)
    plot_kf_state(kf, data, ax)

Discussion questions:

Why do you think one trace from one subject was sufficient to provide a decent fit across all subjects? If you were to go back and change the subject_id and/or image_id for when we fit the data using EM, do you think the fits would be different?

We don’t think the eye exactly follows a linear dynamical system. Nonetheless, that is what we assumed for this exercise when we applied a Kalman filter. Despite the mismatch, these algorithms perform well. Discuss what differences we might find between the true and assumed processes. What mistakes might be likely consequences of these differences?

Finally, recall that the original task was to use this data to help develop models of visual salience. While our Kalman filter is able to provide smooth estimates of observed gaze data, it’s not telling us anything about why the gaze is going in a certain direction. In fact, if we sample data from our parameters and plot them, we get what amounts to a random walk.

kf_state, kf_data = kf.sample(len(data))
ax = plot_gaze_data(kf_data, img=images[2])
plot_kf_state(kf, kf_data, ax)

This should not be surprising, as we have given the model no observed data beyond the pixels at which gaze was detected. We expect that something other than just the previous fixation location drives the latent state that determines where to look next.
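Why a random walk? The fitted \(D\) has both eigenvalues just below 1, so each new state is essentially the previous state plus Gaussian noise. The sketch below is a minimal NumPy sampler for the model \(s_{t+1} = D s_t + w_t,\ y_t = H s_t + v_t\) using the rounded fitted values; `sample_lds` is a hypothetical name, it is not pykalman's `kf.sample`, and, unlike a full sampler, it starts exactly at a hypothetical pixel `mu_0` instead of drawing the first state from \(\mathcal{N}(\mu_0, \Sigma_0)\).

```python
import numpy as np


def sample_lds(D, Q, H, R, mu_0, n_timesteps, seed=0):
    """Sample one latent trajectory and its observations from a
    linear-Gaussian state-space model. The first state is set to mu_0
    (no initial-state noise), an assumption made for simplicity."""
    rng = np.random.default_rng(seed)
    n_state, n_obs = D.shape[0], H.shape[0]
    states = np.zeros((n_timesteps, n_state))
    obs = np.zeros((n_timesteps, n_obs))
    states[0] = mu_0
    for t in range(n_timesteps):
        if t > 0:
            # latent dynamics: previous state mapped by D plus process noise
            states[t] = D @ states[t - 1] + rng.multivariate_normal(np.zeros(n_state), Q)
        # observation: state mapped by H plus measurement noise
        obs[t] = H @ states[t] + rng.multivariate_normal(np.zeros(n_obs), R)
    return states, obs


# Fitted values copied (rounded) from the EM output printed earlier
D = np.array([[1.004, -0.01], [0.005, 0.989]])
Q = np.array([[278.016, 219.292], [219.292, 389.774]])
H = np.array([[0.999, 0.003], [-0.004, 1.01]])
R = np.array([[26.026, 19.596], [19.596, 26.745]])

print(np.linalg.eigvals(D))  # approximately 0.999 and 0.994
states, obs = sample_lds(D, Q, H, R, mu_0=np.array([500.0, 400.0]), n_timesteps=100)
```

With eigenvalues this close to 1 the pull back toward any fixed point is negligible over a few hundred samples, so nothing in the model prefers one image region over another.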

In summary, while the Kalman filter is a good option for smoothing the gaze trajectory itself, especially if using a lower-quality eye tracker or in noisy environmental conditions, a linear dynamical system may not be the right way to approach the much more challenging task of modeling visual saliency.

Bonus

Review on Gaussian joint, marginal and conditional distributions

Assume

\[\begin{eqnarray} z & = & \begin{bmatrix}x \\y\end{bmatrix}\sim \mathcal{N}\left(\begin{bmatrix}a \\b\end{bmatrix}, \begin{bmatrix}A & C \\C^T & B\end{bmatrix}\right) \end{eqnarray}\]

then the marginal distributions are

\[\begin{eqnarray} x & \sim & \mathcal{N}(a, A) \\ y & \sim & \mathcal{N}(b,B) \end{eqnarray}\]

and the conditional distributions are

\[\begin{eqnarray} x|y & \sim & \mathcal{N}(a+CB^{-1}(y-b), A-CB^{-1}C^T) \\ y|x & \sim & \mathcal{N}(b+C^TA^{-1}(x-a), B-C^TA^{-1}C) \end{eqnarray}\]

Important takeaway: given the joint Gaussian distribution, we can derive the conditionals in closed form.
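As a quick numerical illustration of the conditional formula for \(x|y\), here is a small sketch with hypothetical one-dimensional blocks (\(a=0, b=1, A=2, B=1, C=0.5\)) and an observed \(y=3\). The function name `conditional_x_given_y` is made up for this example; it works for any block sizes.

```python
import numpy as np


def conditional_x_given_y(a, b, A, B, C, y):
    """Mean and covariance of x | y when [x, y] is jointly Gaussian with
    mean [a, b] and covariance [[A, C], [C^T, B]]."""
    gain = C @ np.linalg.inv(B)      # C B^{-1}
    mean = a + gain @ (y - b)        # a + C B^{-1} (y - b)
    cov = A - gain @ C.T             # A - C B^{-1} C^T
    return mean, cov


a, b = np.array([0.0]), np.array([1.0])
A, B, C = np.array([[2.0]]), np.array([[1.0]]), np.array([[0.5]])
mean, cov = conditional_x_given_y(a, b, A, B, C, y=np.array([3.0]))
print(mean, cov)  # [1.] [[1.75]]
```

Observing \(y\) shifts the mean of \(x\) from 0 to 1 and shrinks its variance from 2 to 1.75. The Kalman update has the same form: with \(C = \Sigma^{\rm pred} H^T\) and \(B = H\Sigma^{\rm pred}H^T + R\), the factor \(CB^{-1}\) is the Kalman gain.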

Kalman Smoothing

Video 5: Kalman Smoothing and the EM Algorithm

Video available at https://youtu.be/4Ar2mYz1Nms

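Obtain the smoothed estimates by propagating backward from \(t = T\) to \(t = 0\), using the results of the forward pass (\(\hat{\mu}_t^{\rm filter}, \hat{\Sigma}_t^{\rm filter}\)) and the one-step-ahead predicted covariance \(P_t=\hat{\Sigma}_{t+1}^{\rm pred} = D\hat{\Sigma}_t^{\rm filter}D^T + Q\):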

\[\begin{eqnarray} s_t & \sim & \mathcal{N}(\hat{\mu}_t^{\rm smooth}, \hat{\Sigma}_t^{\rm smooth}) \\ \hat{\mu}_t^{\rm smooth} & = & \hat{\mu}_t^{\rm filter}+J_t(\hat{\mu}_{t+1}^{\rm smooth}-D\hat{\mu}_t^{\rm filter}) \\ \hat{\Sigma}_t^{\rm smooth} & = & \hat{\Sigma}_t^{\rm filter}+J_t(\hat{\Sigma}_{t+1}^{\rm smooth}-P_t)J_t^T \\ J_t & = & \hat{\Sigma}_t^{\rm filter}D^T P_t^{-1} \end{eqnarray}\]

This gives us the final estimate for \(s_t\):

\[\begin{eqnarray} \hat{\mu}_t & = & \hat{\mu}_t^{\rm smooth} \\ \hat{\Sigma}_t & = & \hat{\Sigma}_t^{\rm smooth} \end{eqnarray}\]
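To see one backward step of these equations with numbers, here is a scalar sketch with hypothetical values: \(D = 1\), \(Q = 1\), a filtered estimate \(\hat\mu_t^{\rm filter} = 0,\ \hat\Sigma_t^{\rm filter} = 1\), and an already-smoothed next step \(\hat\mu_{t+1}^{\rm smooth} = 2,\ \hat\Sigma_{t+1}^{\rm smooth} = 1\). The function name `rts_step_scalar` is made up for this illustration.

```python
def rts_step_scalar(mu_f, sigma_f, mu_s_next, sigma_s_next, D, Q):
    """One backward smoothing step for a scalar state."""
    P = D * sigma_f * D + Q                                 # P_t: predicted variance at t+1
    J = sigma_f * D / P                                     # J_t = Sigma_t^filter D^T P_t^{-1}
    mu_s = mu_f + J * (mu_s_next - D * mu_f)                # smoothed mean
    sigma_s = sigma_f + J * (sigma_s_next - P) * J          # smoothed variance
    return mu_s, sigma_s, J


print(rts_step_scalar(mu_f=0.0, sigma_f=1.0, mu_s_next=2.0, sigma_s_next=1.0, D=1.0, Q=1.0))
# (1.0, 0.75, 0.5)
```

Here \(P_t = 2\) and \(J_t = 0.5\): the future information pulls the mean halfway toward the smoothed value at \(t+1\), and the variance drops from 1 to 0.75 because \(\hat\Sigma_{t+1}^{\rm smooth} < P_t\).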

Exercise 3: Implement Kalman smoothing

In this exercise you will implement the Kalman smoothing (backward) process. Again, you will focus on writing the expressions for computing the smoothed mean, the smoothed covariance, and the \(J_t\) values.


def kalman_smooth(data, params):
  """ Perform Kalman smoothing (backward pass) on the data given the provided
  system parameters.

  Args:
    data (ndarray): a sequence of observations of shape(n_timesteps, n_dim_obs)
    params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)

  Returns:
    ndarray, ndarray: the smoothed system means and noise covariance values
  """
  # pulled out of the params dict for convenience
  D = params['D']
  Q = params['Q']
  H = params['H']
  R = params['R']

  n_dim_state = D.shape[0]
  n_dim_obs = H.shape[0]

  # first run the forward pass to get the filtered means and covariances
  mu, sigma = kalman_filter(data, params)

  # initialize state mean and covariance estimates
  mu_hat = np.zeros_like(mu)
  sigma_hat = np.zeros_like(sigma)
  mu_hat[-1] = mu[-1]
  sigma_hat[-1] = sigma[-1]

  # smooth the data
  for t in reversed(range(len(data)-1)):
    sigma_pred = D @ sigma[t] @ D.T + Q  # sigma_pred at t+1
    ###########################################################################
    ## TODO for students: compute the smoothed state mean and covariance values
    # Fill out function and remove
    raise NotImplementedError("Student exercise: compute the smoothed state mean and covariance values")
    ###########################################################################

    # write the expression to compute the Kalman gain for the backward process
    J = ...
    # write the expression to compute the smoothed state mean estimate
    mu_hat[t] = ...
    # write the expression to compute the smoothed state noise covariance estimate
    sigma_hat[t] = ...

  return mu_hat, sigma_hat


# Uncomment once the kalman_smooth function is complete
# smoothed_state_means, smoothed_state_covariances = kalman_smooth(obs, params)
# axes = plot_kalman(state, obs, filtered_state_means, color="r",
#                    label="my kf-filter")
# plot_kalman(state, obs, smoothed_state_means, color="b",
#             label="my kf-smoothed", axes=axes)

# to_remove solution
def kalman_smooth(data, params):
  """ Perform Kalman smoothing (backward pass) on the data given the provided
  system parameters.

  Args:
    data (ndarray): a sequence of observations of shape(n_timesteps, n_dim_obs)
    params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)

  Returns:
    ndarray, ndarray: the smoothed system means and noise covariance values
  """
  # pulled out of the params dict for convenience
  D = params['D']
  Q = params['Q']
  H = params['H']
  R = params['R']

  n_dim_state = D.shape[0]
  n_dim_obs = H.shape[0]

  # first run the forward pass to get the filtered means and covariances
  mu, sigma = kalman_filter(data, params)

  # initialize state mean and covariance estimates
  mu_hat = np.zeros_like(mu)
  sigma_hat = np.zeros_like(sigma)
  mu_hat[-1] = mu[-1]
  sigma_hat[-1] = sigma[-1]

  # smooth the data
  for t in reversed(range(len(data)-1)):
    sigma_pred = D @ sigma[t] @ D.T + Q  # sigma_pred at t+1

    # write the expression to compute the Kalman gain for the backward process
    J = sigma[t] @ D.T @ np.linalg.inv(sigma_pred)
    # write the expression to compute the smoothed state mean estimate
    mu_hat[t] = mu[t] + J @ (mu_hat[t+1] - D @ mu[t])
    # write the expression to compute the smoothed state noise covariance estimate
    sigma_hat[t] = sigma[t] + J @ (sigma_hat[t+1] - sigma_pred) @ J.T

  return mu_hat, sigma_hat


smoothed_state_means, smoothed_state_covariances = kalman_smooth(obs, params)
with plt.xkcd():
  axes = plot_kalman(state, obs, filtered_state_means, color="r",
                     label="my kf-filter")
  plot_kalman(state, obs, smoothed_state_means, color="b",
              label="my kf-smoothed", axes=axes)

Forward vs Backward

Now that we have implementations for both, let’s compare their performance by computing the MSE between the filtered (forward) and smoothed (backward) estimated states and the true latent state.

print(f"Filtered MSE: {np.mean((state - filtered_state_means)**2):.3f}")
print(f"Smoothed MSE: {np.mean((state - smoothed_state_means)**2):.3f}")
Filtered MSE: 10.963
Smoothed MSE: 7.087

In this example, the smoothed estimate is clearly superior to the filtered one. This makes sense: the forward pass uses only past measurements, whereas the backward pass also uses future measurements, correcting the forward-pass estimates given all the data we’ve collected.

So why would you ever use Kalman filtering alone, without smoothing? Because Kalman filtering depends only on already observed data (i.e., the past), it can be run in a streaming, or online, setting. Kalman smoothing relies on future data, as it were, and so can only be applied in a batch, or offline, setting. Use Kalman filtering if you need real-time corrections, and Kalman smoothing if you are analyzing already-collected data.

This online case is typically what the brain faces.
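The same comparison can be reproduced outside the tutorial's 2-D example with a self-contained scalar sketch. Everything here is an assumption made for illustration: \(H = 1\), hypothetical parameters \(d = 0.99\), \(q = 1\), \(r = 4\), and the made-up function names `simulate_1d`, `filter_1d`, and `smooth_1d`. The forward pass follows the usual predict-then-update pattern; the backward pass applies the smoothing equations from the Kalman Smoothing section.

```python
import numpy as np


def simulate_1d(n_steps, d, q, r, rng):
    """Scalar LDS: s_t = d s_{t-1} + w_t,  y_t = s_t + v_t."""
    s = np.zeros(n_steps)
    y = np.zeros(n_steps)
    s[0] = rng.normal(0.0, 1.0)
    for t in range(n_steps):
        if t > 0:
            s[t] = d * s[t - 1] + rng.normal(0.0, np.sqrt(q))
        y[t] = s[t] + rng.normal(0.0, np.sqrt(r))
    return s, y


def filter_1d(y, d, q, r, mu_0=0.0, var_0=1.0):
    """Forward pass: the estimate at t uses only y[0..t] (works online)."""
    mu = np.zeros(len(y))
    var = np.zeros(len(y))
    for t in range(len(y)):
        mu_pred = mu_0 if t == 0 else d * mu[t - 1]
        var_pred = var_0 if t == 0 else d * var[t - 1] * d + q
        k = var_pred / (var_pred + r)          # Kalman gain with H = 1
        mu[t] = mu_pred + k * (y[t] - mu_pred)
        var[t] = (1 - k) * var_pred
    return mu, var


def smooth_1d(mu, var, d, q):
    """Backward pass: refines every estimate using all of y (batch only)."""
    mu_s, var_s = mu.copy(), var.copy()        # last step is left unchanged
    for t in reversed(range(len(mu) - 1)):
        p = d * var[t] * d + q                 # P_t, predicted variance at t+1
        j = var[t] * d / p                     # smoother gain J_t
        mu_s[t] = mu[t] + j * (mu_s[t + 1] - d * mu[t])
        var_s[t] = var[t] + j * (var_s[t + 1] - p) * j
    return mu_s, var_s


rng = np.random.default_rng(0)
d, q, r = 0.99, 1.0, 4.0
s, y = simulate_1d(500, d, q, r, rng)
mu_f, var_f = filter_1d(y, d, q, r)
mu_s, var_s = smooth_1d(mu_f, var_f, d, q)
print(f"Filtered MSE: {np.mean((s - mu_f) ** 2):.3f}")
print(f"Smoothed MSE: {np.mean((s - mu_s) ** 2):.3f}")
```

By construction the smoothed variance is never larger than the filtered variance, and the two coincide at the final time step, where there is no future data to use.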

The Expectation-Maximization (EM) Algorithm

  • want to maximize \(\log p(m|\theta)\)

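  • need to marginalize out the latent state \(s\), which is generally not tractable

\[\begin{equation} p(m|\theta)=\int p(m,s|\theta)\,ds \end{equation}\]
  • introduce a probability distribution \(q(s)\) that approximates the posterior distribution of the latent state; because \(\int q(s)\,ds = 1\), we can write

\[\log p(m|\theta) = \log p(m|\theta)\int q(s)\,ds = \int q(s)\log p(m|\theta)\,ds\]
  • which can be rewritten as

\[\begin{equation} \log p(m|\theta) = \mathcal{L}(q,\theta)+\mathrm{KL}\left(q(s)\,\|\,p(s|m,\theta)\right), \quad \text{where } \mathcal{L}(q,\theta)=\int q(s)\log\frac{p(m,s|\theta)}{q(s)}\,ds \end{equation}\]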
  • \(\mathcal{L}(q,\theta)\) contains the joint distribution of \(m\) and \(s\)

  • \(KL(q||p)\) contains the conditional distribution of \(s|m\)
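For completeness, the decomposition follows from the product rule \(p(m,s|\theta) = p(s|m,\theta)\,p(m|\theta)\) and \(\int q(s)\,ds = 1\). Because the KL divergence is never negative, \(\mathcal{L}(q,\theta)\) is a lower bound on \(\log p(m|\theta)\), with equality exactly when \(q(s) = p(s|m,\theta)\):

\[\begin{eqnarray} \log p(m|\theta) &=& \int q(s)\log p(m|\theta)\,ds \\ &=& \int q(s)\log\frac{p(m,s|\theta)}{p(s|m,\theta)}\,ds \\ &=& \int q(s)\log\frac{p(m,s|\theta)}{q(s)}\,ds + \int q(s)\log\frac{q(s)}{p(s|m,\theta)}\,ds \\ &=& \mathcal{L}(q,\theta) + \mathrm{KL}\left(q(s)\,\|\,p(s|m,\theta)\right) \end{eqnarray}\]

This is why the E-step can maximize \(\mathcal{L}\) over \(q\) simply by computing the exact posterior over the latent states, which for a linear dynamical system is what the Kalman filter and smoother provide.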

Expectation step

  • parameters are kept fixed

  • find a good approximation \(q(s)\): maximize the lower bound \(\mathcal{L}(q,\theta)\) with respect to \(q(s)\)

  • (this is what our Kalman filter and smoother already compute)

Maximization step

  • keep distribution \(q(s)\) fixed

  • change parameters to maximize the lower bound \(\mathcal{L}(q,\theta)\)

As mentioned, our Kalman filter and smoother already solve the E-step. The M-step requires further derivation, which is outlined in the next subsection. Rather than having you implement the M-step yourselves, we used the pykalman library, which already implements EM, to explore the experimental eye-tracking data from cognitive neuroscience in Section 3.

The M-step for an LDS

Update the parameters of the probability distribution (see Bishop, Section 13.3.2, Learning in LDS).

For the M-step updates we need the following posterior moments, obtained from the Kalman smoothing results \(\hat{\mu}_t^{\rm smooth}, \hat{\Sigma}_t^{\rm smooth}\) (written below as \(\hat{\mu}_t, \hat{\Sigma}_t\)) and the smoother gains \(J_t\):

\[\begin{eqnarray} E(s_t) &=& \hat{\mu}_t \\ E(s_ts_{t-1}^T) &=& \hat{\Sigma}_t J_{t-1}^T+\hat{\mu}_t\hat{\mu}_{t-1}^T\\ E(s_ts_{t}^T) &=& \hat{\Sigma}_t+\hat{\mu}_t\hat{\mu}_{t}^T \end{eqnarray}\]

Update parameters

Initial parameters

\[\begin{eqnarray} \mu_0^{\rm new}&=& E(s_0)\\ \Sigma_0^{\rm new} &=& E(s_0s_0^T)-E(s_0)E(s_0^T) \end{eqnarray}\]

Hidden (latent) state parameters

\[\begin{eqnarray} D^{\rm new} &=& \left(\sum_{t=2}^T E(s_ts_{t-1}^T)\right)\left(\sum_{t=2}^T E(s_{t-1}s_{t-1}^T)\right)^{-1} \\ Q^{\rm new} &=& \frac{1}{T-1} \sum_{t=2}^T \left[E\big(s_ts_t^T\big) - D^{\rm new}E\big(s_{t-1}s_{t}^T\big) - E\big(s_ts_{t-1}^T\big)\big(D^{\rm new}\big)^T+D^{\rm new}E\big(s_{t-1}s_{t-1}^T\big)\big(D^{\rm new}\big)^{T}\right] \end{eqnarray}\]

Observable (measured) space parameters

\[\begin{eqnarray} H^{\rm new} &=& \left(\sum_{t=1}^T y_t E(s_t^T)\right)\left(\sum_{t=1}^T E(s_t s_t^T)\right)^{-1}\\ R^{\rm new} &=& \frac{1}{T}\sum_{t=1}^T\left[y_ty_t^T-H^{\rm new}E(s_t)y_t^T-y_tE(s_t^T)\big(H^{\rm new}\big)^T+H^{\rm new}E(s_ts_t^T)\big(H^{\rm new}\big)^T\right] \end{eqnarray}\]
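The M-step updates can be collected into a single function. This is a minimal NumPy sketch written for this note, not pykalman's implementation. `lds_m_step` is a hypothetical name; it uses 0-based time indices on arrays of length \(T\) (so the transition sums run over \(t = 1, \dots, T-1\), matching \(t = 2, \dots, T\) in the 1-based notation above), computes the cross moment as \(E(s_ts_{t-1}^T) = \hat\Sigma_t J_{t-1}^T + \hat\mu_t\hat\mu_{t-1}^T\), and assumes the smoother gains \(J_t\) were saved during the backward pass. It returns a dict with the same keys as the tutorial's `params`.

```python
import numpy as np


def lds_m_step(y, mu_hat, sigma_hat, J):
    """Sketch of the LDS M-step.

    y:         (T, n_dim_obs) observations
    mu_hat:    (T, n_dim_state) smoothed means
    sigma_hat: (T, n_dim_state, n_dim_state) smoothed covariances
    J:         (T-1, n_dim_state, n_dim_state) smoother gains J_0 .. J_{T-2}
    """
    T = len(y)
    # E(s_t s_t^T) for every t
    E_ss = sigma_hat + np.einsum('ti,tj->tij', mu_hat, mu_hat)
    # E(s_t s_{t-1}^T) = Sigma_t J_{t-1}^T + mu_t mu_{t-1}^T, for t = 1 .. T-1
    E_ss_lag = (np.einsum('tij,tkj->tik', sigma_hat[1:], J)
                + np.einsum('ti,tj->tij', mu_hat[1:], mu_hat[:-1]))

    S_lag = E_ss_lag.sum(axis=0)     # sum of E(s_t s_{t-1}^T)
    S_prev = E_ss[:-1].sum(axis=0)   # sum of E(s_{t-1} s_{t-1}^T)
    S_curr = E_ss[1:].sum(axis=0)    # sum of E(s_t s_t^T) over t >= 1
    S_all = E_ss.sum(axis=0)         # sum of E(s_t s_t^T) over all t

    # Latent-state parameters
    D_new = S_lag @ np.linalg.inv(S_prev)
    Q_new = (S_curr - D_new @ S_lag.T - S_lag @ D_new.T
             + D_new @ S_prev @ D_new.T) / (T - 1)

    # Observation parameters
    S_yx = y.T @ mu_hat              # sum of y_t E(s_t)^T
    H_new = S_yx @ np.linalg.inv(S_all)
    R_new = (y.T @ y - H_new @ S_yx.T - S_yx @ H_new.T
             + H_new @ S_all @ H_new.T) / T

    # Initial-state parameters: E(s_0) and E(s_0 s_0^T) - E(s_0) E(s_0)^T
    return {'D': D_new, 'Q': Q_new, 'H': H_new, 'R': R_new,
            'mu_0': mu_hat[0], 'sigma_0': sigma_hat[0]}
```

A full EM loop would alternate a smoothing pass with this update. Note that the tutorial's `kalman_smooth` computes `J` at each step but does not return it, so it would need a small modification to supply the gains.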